/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file matmul_macro_v220_intf.h
 * \brief Interface (fallback) declaration of the V220 macro matmul wrapper;
 *        concrete behavior is provided by partial specializations.
 */
#ifndef IMPL_MATMUL_MATMUL_MACRO_V220_INTF_H
#define IMPL_MATMUL_MATMUL_MACRO_V220_INTF_H

#include "kernel_operator.h"
#include "matmul_macro_utils.h"

namespace AscendC {

// Block-local ping-pong flag for the A2/B2 (L0A/L0B) staging buffers.
// Presumably toggled between 0 and 1 to alternate double-buffer halves —
// TODO(review): confirm toggle protocol against the specializations that use it.
__BLOCK_LOCAL__ __inline__ uint64_t gA2B2PingPongFlag_;
// Depth of the double (ping-pong) buffering scheme used for L0A/L0B/bias.
constexpr uint32_t PINGPONG_BUFFER_NUM = 2;

// Mad (matrix multiply-accumulate) precision mode. Enumerator names encode
// "<input type>2<accumulator type>". Kept as a plain enum (not enum class)
// because enumerators are referenced unqualified (e.g. F162F32 below).
enum madtype {
    F162F32,  // fp16 inputs, fp32 accumulation
    F322F32,  // fp32 inputs, fp32 accumulation
    S82S32,   // int8 inputs, int32 accumulation
    S42S32    // int4 inputs, int32 accumulation
};

// =========== mad template ===========
// Template parameters: C matrix type, A matrix type, B matrix type,
// L0C_using_uniflag, L0C_using_hset
// Primary (fallback) template of the macro matmul wrapper. Every member is an
// empty stub; real behavior comes from partial specializations chosen through
// the trailing `typename = void` SFINAE parameter.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, uint16_t UNIFLAG_EN = 0,
    uint16_t GEMV_MODE = 0, bool L0CACHE = false, bool ISA2B2SHARED = false, typename = void>
class MacroMatmul {
public:
    inline __aicore__ MacroMatmul() {}
    // NOTE(review): this destructor was previously declared but never defined,
    // so any instantiation whose object is destroyed would fail at link time.
    // Define it as an empty stub, consistent with the constructor and every
    // other member of this fallback template. (If the declaration-only form
    // was a deliberate "poison" to forbid use of the primary template,
    // restore it and document that intent instead.)
    inline __aicore__ ~MacroMatmul() {}
#ifdef ASCENDC_CPU_DEBUG
    // Ping/pong base addresses for L0A, L0B and the bias buffer. Mutable in
    // CPU-debug builds; compile-time constants on device (see #else branch).
    uint64_t L0A_PING = L0A_PING_D;
    uint64_t L0A_PONG = L0A_PONG_D;
    uint64_t L0B_PING = L0B_PING_D;
    uint64_t L0B_PONG = L0B_PONG_D;
    uint64_t BIAS_PING = BIAS_PING_D;
    uint64_t BIAS_PONG = BIAS_PONG_D;
#else
    constexpr static uint64_t L0A_PING = L0A_PING_D;
    constexpr static uint64_t L0A_PONG = L0A_PONG_D;
    constexpr static uint64_t L0B_PING = L0B_PING_D;
    constexpr static uint64_t L0B_PONG = L0B_PONG_D;
    constexpr static uint64_t BIAS_PING = BIAS_PING_D;
    constexpr static uint64_t BIAS_PONG = BIAS_PONG_D;
#endif
    // args — L1 tile extents/offsets for the A and B matrices
    // (naming: s<A|B>L1<M|K|N>[Offset]_ — TODO(review): confirm units, likely
    // elements vs. fractal blocks, against the specializations)
    uint16_t sAL1M_;
    uint16_t sAL1K_;
    uint16_t sAL1MOffset_;
    uint16_t sAL1KOffset_;
    uint16_t sBL1N_;
    uint16_t sBL1K_;
    uint16_t sBL1NOffset_;
    uint16_t sBL1KOffset_;
    uint16_t sL1BiasOffset_;
    // mad (single MMAD issue) dimensions
    uint16_t sMadM_;
    uint16_t sMadN_;
    uint16_t sMadK_;
    uint16_t sMad0K_;
    uint16_t sL0cInit_; // 0: normal  1: init (first accumulation step)
    uint16_t sL0cLast_; // 0: normal  1: last (final accumulation step)
    uint64_t useL0PingPong_;
    // feature map height (fixed to 1 for plain matmul)
    constexpr static uint16_t sFmH_ = 1;
    // state — current ping/pong selection per buffer
    uint16_t ssAl0PingPongFlag_;
    uint16_t ssBl0PingPongFlag_;
    constexpr static uint16_t ssBiasFull_ = 0;
    uint16_t ssBiasPingPongFlag_;
    uint16_t kDirectionAlign_;
    // instance args
    // 0: format(M, K)
    // 1: format(K, M), need set transpose
    uint16_t ssAmatrixTranspose_;
    // 0: format(K, N), use load3dv2 carry
    // 1: format(N, K), use load2d carry
    uint16_t ssBmatrixTranspose_;
    // 0: no bias
    // 1: fp16
    // 2: fp32
    uint16_t biasType_;
    constexpr static uint16_t typeSize_ = sizeof(A_T);
    A_T aScalar_;
    A_T bScalar_;
    // tpipe — L0A/L0B/bias buffer handles
    TBuf<TPosition::A2> l0aBuf_;
    TBuf<TPosition::B2> l0bBuf_;
    TBuf<TPosition::C2> biasBuf_;
#ifdef ASCENDC_CPU_DEBUG
    // CPU-debug shadow pointers for the three buffers above
    uint64_t pA;
    uint64_t pB;
    uint64_t pBias;
    bool initFlag = false;
#endif

    // All entry points below are deliberate no-ops in this fallback template.
    inline __aicore__ void Init() {}
    inline __aicore__ void Release() {}
    inline __aicore__ void ResetCache() {}
    template <bool noBias = false, bool noTail = false, bool intraBlockPartSum = false,
            ScheduleType scheduleType = ScheduleType::INNER_PRODUCT, IterateOrder iterateOrder = IterateOrder::UNDEF>
    inline __aicore__ void Compute(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias,
        int64_t offsetb = 0, uint8_t subIdx = 0, uint16_t sMadMStep = 0, uint16_t sMadNStep = 0,
        uint32_t posA = 0, uint32_t posB = 0) {}
    template <bool noBias = false>
    inline __aicore__ void ComputeWithMdb(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias, uint64_t kC0Tail, uint64_t kTail,
        uint16_t sMadMStep) {}
    template <bool noBias = false>
    inline __aicore__ void ComputeWithNdb(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias, uint64_t kC0Tail, uint64_t kTail,
        uint16_t sMadNStep) {}
    inline __aicore__ void InitSetFlag() {}
    inline __aicore__ void LoadL12L0BFullLoad(const LocalTensor<B_T> &l1B, uint8_t subIdx,
        uint16_t sMad0K, uint16_t sMadN, uint16_t sBL1N, uint16_t sBL1NOffset, uint16_t sBL1KOffset,
        uint16_t offset) {}
    // Hardware K0 granularity; the fallback reports 0.
    inline __aicore__ constexpr static uint16_t GetHwK0()
    {
        return 0;
    }
    // Mad precision mode; the fallback defaults to fp16-in/fp32-accumulate.
    inline __aicore__ constexpr static madtype GetMode()
    {
        return F162F32;
    }
    constexpr static madtype mode_ = GetMode();
private:
    inline __aicore__ void LoadL12L0A(uint64_t k_inner, uint64_t aPoskPtr, uint16_t usedK,
        const LocalTensor<A_T> &l1A, LocalTensor<A_T> &l0A) {}
    inline __aicore__ void LoadL12L0B(uint64_t k_inner, uint16_t usedK,
        const LocalTensor<B_T> &l1B, LocalTensor<B_T> &l0B) {}
    inline __aicore__ void MmadMacro(const LocalTensor<A_T> &l0A, const LocalTensor<B_T> &l0B,
        const LocalTensor<C_T> &cMatrix, uint16_t mmadK, uint8_t unitFlag, bool l0c_initial) {}
};

} // namespace AscendC
#endif